# from src.eval import build_compile_cache_with_capturing
import subprocess
import os
import ninja

from src.utils import set_gpu_arch
from src.eval import build_compile_cache_with_capturing

################################################################################
# Test for checking if we can capture nvcc errors
################################################################################

correct_kernel_code = "import torch\nimport torch.nn as nn\nfrom torch.utils.cpp_extension import load_inline\n\n# Define the custom CUDA kernel for 4D tensor-matrix multiplication\ntensor_matmul_source = \"\"\"\n#include <torch/extension.h>\n#include <cuda_runtime.h>\n\n__global__ void tensor_matmul_kernel(const float* A, const float* B, float* C, int b, int i, int j, int l, int k) {\n    int idx_b = blockIdx.x;\n    int idx_i = blockIdx.y;\n    int idx_j = blockIdx.z;\n    int idx_k = threadIdx.x;\n\n    if (idx_b < b && idx_i < i && idx_j < j && idx_k < k) {\n        float sum = 0.0f;\n        for (int idx_l = 0; idx_l < l; ++idx_l) {\n            sum += A[idx_b * i * j * l + idx_i * j * l + idx_j * l + idx_l] * B[idx_l * k + idx_k];\n        }\n        C[idx_b * i * j * k + idx_i * j * k + idx_j * k + idx_k] = sum;\n    }\n}\n\ntorch::Tensor tensor_matmul_cuda(torch::Tensor A, torch::Tensor B) {\n    int b = A.size(0);\n    int i = A.size(1);\n    int j = A.size(2);\n    int l = A.size(3);\n    int k = B.size(1);\n\n    auto C = torch::zeros({b, i, j, k}, A.options());\n\n    dim3 blocks(b, i, j);\n    int threads = k;\n\n    tensor_matmul_kernel<<<blocks, threads>>>(A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(), b, i, j, l, k);\n\n    return C;\n}\n\"\"\"\n\ntensor_matmul_cpp_source = (\n    \"torch::Tensor tensor_matmul_cuda(torch::Tensor A, torch::Tensor B);\"\n)\n\n# Compile the inline CUDA code for 4D tensor-matrix multiplication\ntensor_matmul = load_inline(\n    name=\"tensor_matmul\",\n    cpp_sources=tensor_matmul_cpp_source,\n    cuda_sources=tensor_matmul_source,\n    functions=[\"tensor_matmul_cuda\"],\n    verbose=True,\n    extra_cflags=[\"\"],\n    extra_ldflags=[\"\"],\n)\n\n\nclass ModelNew(nn.Module):\n    def __init__(self):\n        super(ModelNew, self).__init__()\n        self.tensor_matmul = tensor_matmul\n\n    def forward(self, A, B):\n        return self.tensor_matmul.tensor_matmul_cuda(A, B)"

faulty_kernel_code = 'import torch\nimport torch.nn as nn\nfrom torch.utils.cpp_extension import load_inline\n\n# Define the custom CUDA kernel for Max Pooling 3D\nmaxpool3d_source = """\n#include <torch/extension.h>\n#include <cuda_runtime.h>\n\n__global__ void maxpool3d_kernel(const float* input, float* output, int* indices, \n                                 int batch_size, int channels, int dim1, int dim2, int dim3,\n                                 int kernel_size, int stride, int padding, int dilation,\n                                 int out_dim1, int out_dim2, int out_dim3) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int b = idx / (channels * out_dim1 * out_dim2 * out_dim3);\n    int c = (idx / (out_dim1 * out_dim2 * out_dim3)) % channels;\n    int d1 = (idx / (out_dim2 * out_dim3)) % out_dim1;\n    int d2 = (idx / out_dim3) % out_dim2;\n    int d3 = idx % out_dim3;\n\n    if (b < batch_size && c < channels && d1 < out_dim1 && d2 < out_dim2 && d3 < out_dim3) {\n        float max_val = -FLT_MAX;\n        int max_idx = -1;\n\n        for (int k1 = 0; k1 < kernel_size; ++k1) {\n            for (int k2 = 0; k2 < kernel_size; ++k2) {\n                for (int k3 = 0; k3 < kernel_size; ++k3) {\n                    int in_d1 = d1 * stride - padding + k1 * dilation;\n                    int in_d2 = d2 * stride - padding + k2 * dilation;\n                    int in_d3 = d3 * stride - padding + k3 * dilation;\n\n                    if (in_d1 >= 0 && in_d1 < dim1 && in_d2 >= 0 && in_d2 < dim2 && in_d3 >= 0 && in_d3 < dim3) {\n                        int in_idx = ((b * channels + c) * dim1 + in_d1) * dim2 * dim3 + in_d2 * dim3 + in_d3;\n                        float val = input[in_idx];\n                        if (val > max_val) {\n                            max_val = val;\n                            max_idx = in_idx;\n                        }\n                    }\n                }\n            }\n        }\n\n        output[idx] = max_val;\n        if (indices != nullptr) {\n            indices[idx] = max_idx;\n        }\n    }\n}\n\ntorch::Tensor maxpool3d_cuda(torch::Tensor input, int kernel_size, int stride, int padding, int dilation, bool return_indices) {\n    int batch_size = input.size(0);\n    int channels = input.size(1);\n    int dim1 = input.size(2);\n    int dim2 = input.size(3);\n    int dim3 = input.size(4);\n\n    int out_dim1 = (dim1 + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;\n    int out_dim2 = (dim2 + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;\n    int out_dim3 = (dim3 + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;\n\n    auto output = torch::zeros({batch_size, channels, out_dim1, out_dim2, out_dim3}, input.options());\n    auto indices = return_indices ? torch::zeros({batch_size, channels, out_dim1, out_dim2, out_dim3}, input.options().dtype(torch::kInt32)) : torch::Tensor();\n\n    int size = batch_size * channels * out_dim1 * out_dim2 * out_dim3;\n    const int block_size = 256;\n    const int num_blocks = (size + block_size - 1) / block_size;\n\n    maxpool3d_kernel<<<num_blocks, block_size>>>(input.data_ptr<float>(), output.data_ptr<float>(), \n                                                 return_indices ? indices.data_ptr<int>() : nullptr,\n                                                 batch_size, channels, dim1, dim2, dim3,\n                                                 kernel_size, stride, padding, dilation,\n                                                 out_dim1, out_dim2, out_dim3);\n\n    if (return_indices) {\n        return torch::make_tuple(output, indices);\n    } else {\n        return output;\n    }\n}\n"""\n\nmaxpool3d_cpp_source = (\n    "torch::Tensor maxpool3d_cuda(torch::Tensor input, int kernel_size, int stride, int padding, int dilation, bool return_indices);"\n)\n\n# Compile the inline CUDA code for Max Pooling 3D\nmaxpool3d = load_inline(\n    name="maxpool3d",\n    cpp_sources=maxpool3d_cpp_source,\n    cuda_sources=maxpool3d_source,\n    functions=["maxpool3d_cuda"],\n    verbose=True,\n    extra_cflags=[""],\n    extra_ldflags=[""],\n)\n\n\nclass ModelNew(nn.Module):\n    """\n    Optimized model that performs Max Pooling 3D using custom CUDA kernels.\n    """\n    def __init__(self, kernel_size: int, stride: int = None, padding: int = 0, dilation: int = 1, return_indices: bool = False, ceil_mode: bool = False):\n        """\n        Initializes the Max Pooling 3D layer.\n\n        Args:\n            kernel_size (int): Size of the kernel for the max pooling operation.\n            stride (int, optional): Stride of the pooling operation. Defaults to None, which means stride is equal to kernel_size.\n            padding (int, optional): Padding applied to the input tensor. Defaults to 0.\n            dilation (int, optional): Spacing between kernel elements. Defaults to 1.\n            return_indices (bool, optional): Whether to return indices of the maximum values. Defaults to False.\n            ceil_mode (bool, optional): When True, the output size is ceil(input_size / stride) instead of floor. Defaults to False.\n        """\n        super(ModelNew, self).__init__()\n        self.kernel_size = kernel_size\n        self.stride = stride if stride is not None else kernel_size\n        self.padding = padding\n        self.dilation = dilation\n        self.return_indices = return_indices\n        self.ceil_mode = ceil_mode\n        self.maxpool3d = maxpool3d\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        """\n        Applies Max Pooling 3D to the input tensor using custom CUDA kernels.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, channels, dim1, dim2, dim3).\n\n        Returns:\n            torch.Tensor: Output tensor with Max Pooling 3D applied.\n        """\n        return self.maxpool3d.maxpool3d_cuda(x, self.kernel_size, self.stride, self.padding, self.dilation, self.return_indices)'


set_gpu_arch(["Ada"]) # replace with whatever device architecthre you have

test_build_dir = "test_build_dir"

print("Testing Correct Kernel Code")
status, stdout, err = build_compile_cache_with_capturing(correct_kernel_code, verbose=False, build_dir=test_build_dir)
print("status: ", status)
print("stdout: ", stdout)
print("err: ", err)
assert status == 0, "Correct Code should compile"

print("Testing Faulty Kernel Code")
status, stdout, err = build_compile_cache_with_capturing(faulty_kernel_code, verbose=False, build_dir=test_build_dir)
print("status: ", status)
print("stdout: ", stdout)
print("err: ", err)
assert status != 0, "Faulty Code should not compile"
assert len(stdout) > 0, "stdout should not be empty"
assert len(err) > 0, "err should not be empty"

print("~~~~TEST PASSED~~~~")
